{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multi GPU Test\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/MONAI/blob/master/examples/notebooks/multi_gpu_test.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "Note: you may need to restart the kernel to use updated packages.\n"
    }
   ],
   "source": [
    "%pip install -qU \"monai[ignite]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "MONAI version: 0.2.0\nPython version: 3.7.5 (default, Nov  7 2019, 10:50:52)  [GCC 8.3.0]\nNumpy version: 1.19.1\nPytorch version: 1.6.0\n\nOptional dependencies:\nPytorch Ignite version: 0.3.0\nNibabel version: NOT INSTALLED or UNKNOWN VERSION.\nscikit-image version: NOT INSTALLED or UNKNOWN VERSION.\nPillow version: NOT INSTALLED or UNKNOWN VERSION.\nTensorboard version: NOT INSTALLED or UNKNOWN VERSION.\n\nFor details about installing the optional dependencies, please visit:\n    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n\n"
    }
   ],
   "source": [
    "# Copyright 2020 MONAI Consortium\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "\n",
    "import torch\n",
    "\n",
    "from monai.config import print_config\n",
    "from monai.engines import create_multigpu_supervised_trainer\n",
    "from monai.networks.nets import UNet\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test GPUs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 1e-3\n",
    "\n",
    "net = UNet(\n",
    "    dimensions=2,\n",
    "    in_channels=1,\n",
    "    out_channels=1,\n",
    "    channels=(16, 32, 64, 128, 256),\n",
    "    strides=(2, 2, 2, 2),\n",
    "    num_res_units=2,\n",
    ")\n",
    "\n",
    "\n",
    "def fake_loss(y_pred, y):\n",
    "    return (y_pred[0] + y).sum()\n",
    "\n",
    "\n",
    "def fake_data_stream():\n",
    "    while True:\n",
    "        yield torch.rand((10, 1, 64, 64)), torch.rand((10, 1, 64, 64))"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 1 GPU"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "source": [
    "opt = torch.optim.Adam(net.parameters(), lr)\n",
    "trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [torch.device(\"cuda:0\")])\n",
    "trainer.run(fake_data_stream(), 2, 2)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 4,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "State:\n\titeration: 4\n\tepoch: 2\n\tepoch_length: 2\n\tmax_epochs: 2\n\toutput: 40707.8984375\n\tbatch: <class 'tuple'>\n\tmetrics: <class 'dict'>\n\tdataloader: <class 'generator'>\n\tseed: 12"
     },
     "metadata": {},
     "execution_count": 4
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "### all GPUs"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "source": [
    "opt = torch.optim.Adam(net.parameters(), lr)\n",
    "trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, None)\n",
    "trainer.run(fake_data_stream(), 2, 2)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "execution_count": 5,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "State:\n\titeration: 4\n\tepoch: 2\n\tepoch_length: 2\n\tmax_epochs: 2\n\toutput: 35669.37109375\n\tbatch: <class 'tuple'>\n\tmetrics: <class 'dict'>\n\tdataloader: <class 'generator'>\n\tseed: 12"
     },
     "metadata": {},
     "execution_count": 5
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "### CPU"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "source": [
    "opt = torch.optim.Adam(net.parameters(), lr)\n",
    "trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [])\n",
    "trainer.run(fake_data_stream(), 2, 2)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 6,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "State:\n\titeration: 4\n\tepoch: 2\n\tepoch_length: 2\n\tmax_epochs: 2\n\toutput: 29662.359375\n\tbatch: <class 'tuple'>\n\tmetrics: <class 'dict'>\n\tdataloader: <class 'generator'>\n\tseed: 12"
     },
     "metadata": {},
     "execution_count": 6
    }
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
